import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# Fully connected network for regression (single scalar output).
class Net(nn.Module):
    """Fully connected regression network.

    Architecture: ``input_size -> hidden_size -> hidden_size -> 1``, with a
    ReLU after each of the two hidden layers and no activation on the output.

    Args:
        input_size: Number of input features (size of the last input dim).
        hidden_size: Width of both hidden layers. Defaults to 200, the width
            this network has always used, so existing callers and saved
            ``state_dict``s are unaffected.
    """

    def __init__(self, input_size, hidden_size=200):
        super(Net, self).__init__()
        # Layers are created in the same order as before so parameter order
        # (and hence optimizer state / RNG-driven init) is unchanged.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Map ``x`` of shape ``(*, input_size)`` to shape ``(*, 1)``."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Feature extractor: two fully connected layers, each followed by ReLU.
class Fea(nn.Module):
    """Feature extractor built from two ReLU-activated fully connected layers.

    Maps inputs of shape ``(*, input_size)`` to features of shape
    ``(*, hidden_size)``.

    Args:
        input_size: Number of input features (size of the last input dim).
        hidden_size: Width of both layers and of the output representation.
            Defaults to 200, which also matches ``Reg``'s default
            ``input_size``; existing callers are unaffected.
    """

    def __init__(self, input_size, hidden_size=200):
        super(Fea, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        """Return the non-negative (post-ReLU) feature representation of ``x``."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x


class Reg(nn.Module):
    """Linear regression head producing a single output per example.

    Args:
        input_size: Number of input features; defaults to 200.
    """

    def __init__(self, input_size=200):
        super().__init__()
        self.fc = nn.Linear(in_features=input_size, out_features=1)

    def forward(self, x):
        """Apply the single linear layer (no activation) to ``x``."""
        return self.fc(x)

class FeaReg(nn.Module):
    """Shared feature extractor followed by one or two regression heads.

    When ``reg1`` is provided, the output is the arithmetic mean of both heads
    applied to the same features; otherwise only ``reg0`` is used.

    Args:
        fea: Module mapping raw inputs to features.
        reg0: Regression head applied to the features.
        reg1: Optional second regression head. If given, its output is
            averaged with ``reg0``'s.
    """

    def __init__(self, fea, reg0, reg1=None):
        super(FeaReg, self).__init__()
        self.fea = fea
        self.reg0 = reg0
        self.reg1 = reg1

    def forward(self, x):
        features = self.fea(x)
        # Identity test: `!= None` goes through __ne__, which is not a
        # reliable way to check for None on arbitrary objects.
        if self.reg1 is None:
            return self.reg0(features)
        # reg0 is evaluated before reg1, as in the original implementation.
        return 0.5 * (self.reg0(features) + self.reg1(features))

class SqueezeLastTwo(nn.Module):
    """Reshape ``x`` to ``(x.shape[0], x.shape[1])``, dropping trailing dims.

    Intended for inputs such as ``(N, C, 1, 1)``. Unlike a plain
    ``Tensor.squeeze()``, the leading dimension is kept even when ``N == 1``.
    The input's element count must equal ``x.shape[0] * x.shape[1]``;
    otherwise ``Tensor.view`` raises. The result shares storage with ``x``.
    """

    def forward(self, x):
        leading, second = x.shape[0], x.shape[1]
        return x.view(leading, second)


